import java.util.*;

/**
 * Partitions a singly linked list around a pivot value.
 */
public class Partition {
    /**
     * Rearranges the list so that every node whose {@code val} is less than {@code x} comes before
     * every node whose {@code val} is greater than or equal to {@code x}. Nodes keep their original
     * relative order within each group, so the partition is stable.
     *
     * <p>The existing nodes are relinked in place, which means the input list is modified. The only
     * nodes allocated are two local sentinel heads, and these are never part of the result.
     *
     * @param head first node of the list to partition; may be {@code null} for an empty list
     * @param x pivot value
     * @return head of the partitioned list, or {@code null} if {@code head} is {@code null}
     */
    public ListNode partition(ListNode head, int x) {
        // Sentinel heads let each group be appended to without special-casing its first node.
        ListNode lessHead = new ListNode(0);
        ListNode notLessHead = new ListNode(0);
        ListNode lessTail = lessHead;
        ListNode notLessTail = notLessHead;

        for (ListNode node = head; node != null; node = node.next) {
            if (node.val < x) {
                lessTail.next = node;
                lessTail = node;
            } else {
                notLessTail.next = node;
                notLessTail = node;
            }
        }

        // The last "not less" node may still point at a later "less" node from the original list.
        // Terminate it before splicing, otherwise the result would contain a cycle.
        notLessTail.next = null;
        lessTail.next = notLessHead.next;
        return lessHead.next;
    }
}
